{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0b6b37c9",
   "metadata": {},
   "source": [
    "# Using a Custom DataBuilder in SecretFlow (TensorFlow)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0e57a23",
   "metadata": {},
   "source": [
    "The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8528a86e",
   "metadata": {},
   "source": [
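    "This tutorial shows how to load data with a custom DataBuilder and train a model in SecretFlow's multi-party secure environment.\n",
    "It uses an image classification task on the Flower dataset to demonstrate federated learning with a custom DataBuilder."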
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b03e55b2",
   "metadata": {},
   "source": [
    "## Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c08ecd6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "812c6aea",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-04-17 15:18:33,602\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "sf.init(['alice', 'bob', 'charlie'], address=\"local\", log_to_driver=False)\n",
    "alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25c38ba4",
   "metadata": {},
   "source": [
    "## API Overview"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "509e5461",
   "metadata": {},
   "source": [
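    "SecretFlow's `FLModel` supports loading data through a custom DataBuilder, which lets you handle data input more flexibly to suit your needs.\n",
    "The following example shows how to use a custom DataBuilder for federated model training."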
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff157baf",
   "metadata": {},
   "source": [
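    "Steps for using a DataBuilder:\n",
    "1. Develop with a single-machine engine (TensorFlow or PyTorch) to obtain a builder function that creates the Dataset.\n",
    "2. Wrap each party's builder function in a factory such as `create_dataset_builder`. *Note: the wrapped `dataset_builder` function must accept a `stage` argument.*\n",
    "3. Build `data_builder_dict`, which maps each party's `PYU` to its `dataset_builder`.\n",
    "4. Pass `data_builder_dict` as the `dataset_builder` argument of `fit`, and pass whatever input the builders need as the `x` argument (in this example, the path to the image folder).\n"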
   ]
  },
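  {
   "cell_type": "markdown",
   "id": "3f9a1c20",
   "metadata": {},
   "source": [
    "The steps above can be illustrated with a framework-free sketch. It is not SecretFlow code: a Python list of batches stands in for a `tf.data.Dataset`, and the input `x` is assumed to be a list of samples rather than a folder path. It only shows the shape of the contract: an outer factory that captures configuration such as `batch_size`, and an inner `dataset_builder(x, stage)` that returns `(dataset, steps_per_epoch)`. The 80/20 split is a hypothetical choice that mirrors this tutorial.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def create_dataset_builder(batch_size=32):\n",
    "    # Outer factory: captures configuration such as batch_size.\n",
    "    def dataset_builder(x, stage='train'):\n",
    "        # x is whatever is passed to fit() as x; assumed here to be a list of samples.\n",
    "        if stage not in ('train', 'eval'):\n",
    "            raise ValueError(f'Unsupported stage: {stage}')\n",
    "        # Hypothetical 80/20 split: the last 20% of samples are used for evaluation.\n",
    "        num_val = int(0.2 * len(x))\n",
    "        cut = len(x) - num_val\n",
    "        samples = x[:cut] if stage == 'train' else x[cut:]\n",
    "        # A list of batches stands in for a tf.data.Dataset.\n",
    "        batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]\n",
    "        steps_per_epoch = math.ceil(len(samples) / batch_size)\n",
    "        return batches, steps_per_epoch\n",
    "\n",
    "    return dataset_builder\n",
    "\n",
    "\n",
    "builder = create_dataset_builder(batch_size=4)\n",
    "train_batches, train_steps = builder(list(range(10)), stage='train')\n",
    "eval_batches, eval_steps = builder(list(range(10)), stage='eval')\n",
    "print(train_steps, eval_steps)  # 2 1\n",
    "```"
   ]
  },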
  {
   "cell_type": "markdown",
   "id": "fad75fd6",
   "metadata": {},
   "source": [
    "To use a DataBuilder in FLModel, you must first define `data_builder_dict`. Each party's builder must return a `tf.data.Dataset` together with `steps_per_epoch`, and every party must return the same `steps_per_epoch`.\n",
    "```python\n",
    "data_builder_dict = {\n",
    "    alice: create_alice_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),  # the builder created here must return (dataset, steps_per_epoch)\n",
    "    bob: create_bob_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),  # the builder created here must return (dataset, steps_per_epoch)\n",
    "}\n",
    "```"
   ]
  },
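  {
   "cell_type": "markdown",
   "id": "7b2d4e61",
   "metadata": {},
   "source": [
    "Since every party must return the same `steps_per_epoch`, a quick arithmetic check before training can catch mismatches early. A minimal sketch, assuming each builder computes `steps_per_epoch = ceil(num_samples / batch_size)` as the builder in this tutorial does; `check_steps_per_epoch` and the sample counts are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def check_steps_per_epoch(sample_counts, batch_size):\n",
    "    # sample_counts maps each party to the number of samples its builder will see.\n",
    "    steps = {party: math.ceil(n / batch_size) for party, n in sample_counts.items()}\n",
    "    if len(set(steps.values())) != 1:\n",
    "        raise ValueError(f'steps_per_epoch differs between parties: {steps}')\n",
    "    # All parties agree, so return the shared value.\n",
    "    return next(iter(steps.values()))\n",
    "\n",
    "\n",
    "print(check_steps_per_epoch({'alice': 349, 'bob': 349}, batch_size=32))  # 11\n",
    "try:\n",
    "    check_steps_per_epoch({'alice': 349, 'bob': 300}, batch_size=32)\n",
    "except ValueError as err:\n",
    "    print(err)  # steps_per_epoch differs between parties: {'alice': 11, 'bob': 10}\n",
    "```\n",
    "\n",
    "Parties do not need identical sample counts: with `batch_size=32`, 349 and 330 samples both give 11 steps, whereas 300 samples give 10."
   ]
  },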
  {
   "cell_type": "markdown",
   "id": "c55b0faf",
   "metadata": {},
   "source": [
    "## Download the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07ffc09e",
   "metadata": {},
   "source": [
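    "Flower dataset: the flower dataset consists of color photos of 5 kinds of flowers: daisy, dandelion, rose, sunflower, and tulip. The full TensorFlow release contains 3,670 images: daisy (633), dandelion (898), rose (641), sunflower (699), and tulip (799). The photos are taken from different angles and under different lighting, and they vary in size, so they are resized when loaded (to 180x180 in this tutorial). The dataset is commonly used to train and test image classification and other machine learning algorithms. In the run recorded in this notebook, the loading cell below found 436 images, so the copy used there is a subset of the full release.  \n",
    "  \n",
    "Download URL: [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz) (the code below downloads the archive from a SecretFlow-hosted mirror)\n",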
    "<img alt=\"flower_dataset_demo.png\" src=\"resources/flower_dataset_demo.png\" width=\"600\">  \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef1f5d14",
   "metadata": {},
   "source": [
    "### Download and Extract the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ff9720cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz\n",
      "67588319/67588319 [==============================] - 1s 0us/step\n"
     ]
    }
   ],
   "source": [
    "import tempfile\n",
    "import tensorflow as tf\n",
    "\n",
    "_temp_dir = tempfile.mkdtemp()\n",
    "path_to_flower_dataset = tf.keras.utils.get_file(\n",
    "    \"flower_photos\",\n",
    "    \"https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz\",\n",
    "    untar=True,\n",
    "    cache_dir=_temp_dir,\n",
    ")"
   ]
  },
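  {
   "cell_type": "markdown",
   "id": "c4e8a905",
   "metadata": {},
   "source": [
    "`tf.keras.utils.image_dataset_from_directory` infers class labels from sub-directory names, so the folder extracted above should contain one sub-directory per class. A small sketch for inspecting the downloaded copy before building the dataset; `count_images_per_class` is a hypothetical helper, and the extension set assumes the formats Keras accepts (bmp, gif, jpeg, jpg, png).\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "IMAGE_SUFFIXES = {'.bmp', '.gif', '.jpeg', '.jpg', '.png'}\n",
    "\n",
    "\n",
    "def count_images_per_class(root):\n",
    "    # Each sub-directory of root is one class; loose files such as LICENSE.txt are ignored.\n",
    "    counts = {}\n",
    "    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):\n",
    "        counts[class_dir.name] = sum(\n",
    "            1 for f in class_dir.rglob('*')\n",
    "            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES\n",
    "        )\n",
    "    return counts\n",
    "\n",
    "\n",
    "# Usage in this notebook (path_to_flower_dataset comes from the download cell):\n",
    "# counts = count_images_per_class(path_to_flower_dataset)\n",
    "# print(counts, sum(counts.values()))\n",
    "```"
   ]
  },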
  {
   "cell_type": "markdown",
   "id": "9d6f1057",
   "metadata": {},
   "source": [
    "# Next, Build the Custom DataBuilder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83d98ab3",
   "metadata": {},
   "source": [
    "## 1. Develop the DataBuilder with a Single-Machine Engine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c81834d",
   "metadata": {},
   "source": [
    "When developing the DataBuilder, you can simply follow ordinary single-machine development logic.  \n",
    "The only goal is to build a `tf.data.Dataset` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dfe48fe7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 436 files belonging to 5 classes.\n",
      "Using 349 files for training.\n",
      "Using 87 files for validation.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-04-10 13:16:34.492390: E tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import tensorflow as tf\n",
    "\n",
    "img_height = 180\n",
    "img_width = 180\n",
    "batch_size = 32\n",
    "# In this example, we use the TensorFlow interface for development.\n",
    "data_set = tf.keras.utils.image_dataset_from_directory(\n",
    "    path_to_flower_dataset,\n",
    "    validation_split=0.2,\n",
    "    subset=\"both\",\n",
    "    seed=123,\n",
    "    image_size=(img_height, img_width),\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
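  {
   "cell_type": "markdown",
   "id": "9d1f3b72",
   "metadata": {},
   "source": [
    "The 349/87 split printed above can be reproduced by hand. This sketch assumes Keras sizes the validation subset as `int(validation_split * num_files)` and uses the remaining files for training, which is consistent with the numbers above; `split_sizes` is a hypothetical helper. Because `subset='both'` returns both subsets from a single call with one `seed`, the training and validation files do not overlap.\n",
    "\n",
    "```python\n",
    "def split_sizes(num_files, validation_split):\n",
    "    # Validation size is truncated toward zero; the remaining files are used for training.\n",
    "    num_val = int(validation_split * num_files)\n",
    "    return num_files - num_val, num_val\n",
    "\n",
    "\n",
    "print(split_sizes(436, 0.2))  # (349, 87)\n",
    "```"
   ]
  },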
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8106f107",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = data_set[0]\n",
    "test_set = data_set[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "37dd53fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'> <class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>\n"
     ]
    }
   ],
   "source": [
    "print(type(train_set), type(test_set))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f4199b24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.shape = (32, 180, 180, 3)\n",
      "y.shape = (32,)\n"
     ]
    }
   ],
   "source": [
    "x, y = next(iter(train_set))\n",
    "print(f\"x.shape = {x.shape}\")\n",
    "print(f\"y.shape = {y.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "020e0c15",
   "metadata": {},
   "source": [
    "## 2. Wrap the DataBuilder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfbb0e6d",
   "metadata": {},
   "source": [
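    "At runtime, the DataBuilder is shipped to each party's machine and executed there, so it must be wrapped in a factory function that can be serialized. The imports are placed inside the inner function so that they are resolved on the machine where it runs.  \n",
    "Note: **FLModel requires the DataBuilder to return two values: (data_set, steps_per_epoch)**. The builder below receives the input passed as `x` (here, a folder path) and a `stage` argument, and handles `stage='train'` and `stage='eval'`."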
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d36bee75",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ):\n",
    "    def dataset_builder(folder_path, stage=\"train\"):\n",
    "        import math\n",
    "\n",
    "        import tensorflow as tf\n",
    "\n",
    "        img_height = 180\n",
    "        img_width = 180\n",
    "        data_set = tf.keras.utils.image_dataset_from_directory(\n",
    "            folder_path,\n",
    "            validation_split=0.2,\n",
    "            subset=\"both\",\n",
    "            seed=123,\n",
    "            image_size=(img_height, img_width),\n",
    "            batch_size=batch_size,\n",
    "        )\n",
    "        if stage == \"train\":\n",
    "            train_dataset = data_set[0]\n",
    "            train_step_per_epoch = math.ceil(\n",
    "                len(data_set[0].file_paths) / batch_size\n",
    "            )\n",
    "            return train_dataset, train_step_per_epoch\n",
    "        elif stage == \"eval\":\n",
    "            eval_dataset = data_set[1]\n",
    "            eval_step_per_epoch = math.ceil(\n",
    "                len(data_set[1].file_paths) / batch_size\n",
    "            )\n",
    "            return eval_dataset, eval_step_per_epoch\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported stage: {stage}\")\n",
    "\n",
    "    return dataset_builder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "198128c5",
   "metadata": {},
   "source": [
    "## 3. Build data_builder_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56a73a04",
   "metadata": {},
   "source": [
    "In the horizontal scenario, every party processes its data with the same logic, so a single wrapped DataBuilder factory is enough.  \n",
    "Next, we build `data_builder_dict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c051b566",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_builder_dict = {\n",
    "    alice: create_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),\n",
    "    bob: create_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8e36c7d",
   "metadata": {},
   "source": [
    "## 4. Pass data_builder_dict to the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5602707d",
   "metadata": {},
   "source": [
    "# Next, Define the Model and Train It with the Custom DataBuilder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "feea334e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_conv_flower_model(input_shape, num_classes, name='model'):\n",
    "    def create_model():\n",
    "        import tensorflow as tf\n",
    "        from tensorflow import keras\n",
    "        # Create model\n",
    "\n",
    "        model = keras.Sequential(\n",
    "            [\n",
    "                keras.Input(shape=input_shape),\n",
    "                tf.keras.layers.Rescaling(1.0 / 255),\n",
    "                tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
    "                tf.keras.layers.MaxPooling2D(),\n",
    "                tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
    "                tf.keras.layers.MaxPooling2D(),\n",
    "                tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
    "                tf.keras.layers.MaxPooling2D(),\n",
    "                tf.keras.layers.Flatten(),\n",
    "                tf.keras.layers.Dense(128, activation='relu'),\n",
    "                tf.keras.layers.Dense(num_classes),\n",
    "            ]\n",
    "        )\n",
    "        # Compile model\n",
    "        model.compile(\n",
    "            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "            optimizer='adam',\n",
    "            metrics=[\"accuracy\"],\n",
    "        )\n",
    "        return model\n",
    "\n",
    "    return create_model"
   ]
  },
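  {
   "cell_type": "markdown",
   "id": "e6a05c38",
   "metadata": {},
   "source": [
    "To see how large the classifier head of the model above is, the spatial size after each block can be worked out by hand. A sketch assuming the Keras defaults the model relies on: `Conv2D(32, 3)` uses `'valid'` padding, which shrinks each side by 2, and `MaxPooling2D()` uses a 2x2 pool with stride 2, which halves each side (rounding down). `conv_flower_feature_size` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def conv_flower_feature_size(side=180, filters=32, blocks=3):\n",
    "    # Each block is Conv2D(kernel 3, 'valid' padding) followed by MaxPooling2D(2x2, stride 2).\n",
    "    for _ in range(blocks):\n",
    "        side = side - 2  # a 3x3 valid convolution drops one pixel at each border\n",
    "        side = side // 2  # 2x2 max pooling with stride 2, rounding down\n",
    "    return side, side * side * filters\n",
    "\n",
    "\n",
    "side, flat = conv_flower_feature_size()\n",
    "dense_params = flat * 128 + 128  # weights plus biases of Dense(128)\n",
    "print(side, flat, dense_params)  # 20 12800 1638528\n",
    "```\n",
    "\n",
    "So most of this model's roughly 1.66 million parameters sit in the first `Dense` layer."
   ]
  },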
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "60414cb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.ml.nn import FLModel\n",
    "from secretflow.security.aggregation import SecureAggregator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f368538f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.\n"
     ]
    }
   ],
   "source": [
    "device_list = [alice, bob]\n",
    "aggregator = SecureAggregator(charlie, [alice, bob])\n",
    "\n",
    "# prepare model\n",
    "num_classes = 5\n",
    "input_shape = (180, 180, 3)\n",
    "\n",
    "# keras model\n",
    "model = create_conv_flower_model(input_shape, num_classes)\n",
    "\n",
    "\n",
    "fed_model = FLModel(\n",
    "    device_list=device_list,\n",
    "    model=model,\n",
    "    aggregator=aggregator,\n",
    "    backend=\"tensorflow\",\n",
    "    strategy=\"fed_avg_w\",\n",
    "    random_seed=1234,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32d97af9",
   "metadata": {},
   "source": [
    "The input of our dataset builder is the path to an image folder, so the data passed to `fit` must be a `Dict` that maps each party to its own folder path:\n",
    "```python\n",
    "data = {\n",
    "    alice: folder_path_of_alice,\n",
    "    bob: folder_path_of_bob\n",
    "}\n",
    "```"
   ]
  },
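  {
   "cell_type": "markdown",
   "id": "1b7c9e44",
   "metadata": {},
   "source": [
    "Before calling `fit`, each party's builder can also be called locally on that party's input to catch mismatched keys or bad inputs early. A minimal sketch: `dry_run_builders` and `fake_builder` are hypothetical, plain strings stand in for the `PYU` keys, and the builders are assumed to take `(x, stage)` like the one in this tutorial.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def dry_run_builders(data, builders, stage='train'):\n",
    "    # Every party that supplies input must also supply a builder, and vice versa.\n",
    "    mismatched = set(data) ^ set(builders)\n",
    "    if mismatched:\n",
    "        raise ValueError(f'parties without both an input and a builder: {mismatched}')\n",
    "    # Call each builder on its own party's input and collect steps_per_epoch.\n",
    "    return {party: builder(data[party], stage=stage)[1] for party, builder in builders.items()}\n",
    "\n",
    "\n",
    "def fake_builder(x, stage='train'):\n",
    "    # Stand-in builder: treats x as a list of samples with batch_size=32.\n",
    "    return x, math.ceil(len(x) / 32)\n",
    "\n",
    "\n",
    "data = {'alice': list(range(349)), 'bob': list(range(340))}\n",
    "builders = {'alice': fake_builder, 'bob': fake_builder}\n",
    "print(dry_run_builders(data, builders))  # {'alice': 11, 'bob': 11}\n",
    "```"
   ]
  },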
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "de4b659a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7f7b7a28b8e0>, 'x': {alice: '../../public_dataset/datasets/flower_photos', bob: '../../public_dataset/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 5, 'verbose': 1, 'callbacks': None, 'validation_data': {alice: '../../public_dataset/datasets/flower_photos', bob: '../../public_dataset/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {alice: <function create_dataset_builder.<locals>.dataset_builder at 0x7f7b7a2bb1f0>, bob: <function create_dataset_builder.<locals>.dataset_builder at 0x7f7b7a2bb0d0>}}\n",
      "32it [00:18,  1.71it/s, epoch: 1/5 -  loss:1.5339548587799072  accuracy:0.3142559826374054  val_loss:1.582740068435669  val_accuracy:0.2874999940395355 ]\n",
      "100%|██████████| 8/8 [00:05<00:00,  1.51it/s, epoch: 2/5 -  loss:1.4520319700241089  accuracy:0.36693549156188965  val_loss:1.3319271802902222  val_accuracy:0.40416666865348816 ]\n",
      "100%|██████████| 8/8 [00:05<00:00,  1.54it/s, epoch: 3/5 -  loss:1.2720597982406616  accuracy:0.45766130089759827  val_loss:1.3382091522216797  val_accuracy:0.47083333134651184 ]\n",
      "100%|██████████| 8/8 [00:05<00:00,  1.50it/s, epoch: 4/5 -  loss:1.229131817817688  accuracy:0.5040322542190552  val_loss:1.3033963441848755  val_accuracy:0.4375 ]\n",
      "100%|██████████| 8/8 [00:05<00:00,  1.59it/s, epoch: 5/5 -  loss:1.3306885957717896  accuracy:0.4301075339317322  val_loss:2.1492652893066406  val_accuracy:0.25833332538604736 ]\n"
     ]
    }
   ],
   "source": [
    "data = {\n",
    "    alice: path_to_flower_dataset,\n",
    "    bob: path_to_flower_dataset,\n",
    "}\n",
    "history = fed_model.fit(\n",
    "    data,\n",
    "    None,\n",
    "    validation_data=data,\n",
    "    epochs=5,\n",
    "    batch_size=32,\n",
    "    aggregate_freq=2,\n",
    "    sampler_method=\"batch\",\n",
    "    random_seed=1234,\n",
    "    dp_spent_step_freq=1,\n",
    "    dataset_builder=data_builder_dict,\n",
    ")"
   ]
  },
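  {
   "cell_type": "markdown",
   "id": "58d2f6a3",
   "metadata": {},
   "source": [
    "In the training run above, alice and bob read the same folder, so they train on identical data. To experiment with a more realistic horizontal setting, one option is to give each party a disjoint subset of the images while keeping the one-sub-directory-per-class layout. A minimal sketch; `split_into_party_folders`, the round-robin assignment, and the destination layout are assumptions, not part of SecretFlow.\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def split_into_party_folders(src_root, dst_root, parties):\n",
    "    # Assign each class's images to parties round-robin, so the subsets are disjoint.\n",
    "    src_root, dst_root = Path(src_root), Path(dst_root)\n",
    "    party_dirs = {party: dst_root / party for party in parties}\n",
    "    for class_dir in sorted(p for p in src_root.iterdir() if p.is_dir()):\n",
    "        # Create every class folder for every party so all parties see the same class list.\n",
    "        for party_dir in party_dirs.values():\n",
    "            (party_dir / class_dir.name).mkdir(parents=True, exist_ok=True)\n",
    "        files = sorted(f for f in class_dir.iterdir() if f.is_file())\n",
    "        for i, f in enumerate(files):\n",
    "            target = party_dirs[parties[i % len(parties)]] / class_dir.name\n",
    "            shutil.copy2(f, target / f.name)\n",
    "    return {party: str(path) for party, path in party_dirs.items()}\n",
    "\n",
    "\n",
    "# Hypothetical usage: build per-party folders, then map them to the PYUs for fit().\n",
    "# paths = split_into_party_folders(path_to_flower_dataset, '/tmp/flower_parties', ['alice', 'bob'])\n",
    "# data = {alice: paths['alice'], bob: paths['bob']}\n",
    "```\n",
    "\n",
    "Because the parties' sample counts may differ slightly after such a split, it is worth confirming that their `steps_per_epoch` still match before training."
   ]
  },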
  {
   "cell_type": "markdown",
   "id": "54ff96a9",
   "metadata": {},
   "source": [
    "# Next, Try It with Your Own Dataset"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
